import torch
import torch.nn as nn
import numpy as np
import heapq
import utils
import torch.jit


# Softmax over the last axis; argmin_smooth normalises its two stacked logits with it.
softmax_func = nn.Softmax(dim=-1)
# Separate instance with the same last-axis configuration.
softmax_funcfinal = nn.Softmax(dim=-1)


def min_smooth(x1, x2, beta: torch.Tensor = torch.tensor(1.0)):
    """Smooth (log-sum-exp) minimum of ``x1`` and ``x2`` with sharpness ``beta``.

    Evaluates ``-(1/beta) * log(exp(-beta*x1) + exp(-beta*x2))`` stably by
    factoring out ``m = min(beta*x1, beta*x2)``, so every exponent is <= 0.
    The result never exceeds the hard minimum and approaches it as ``beta``
    grows (the gap is at most ``log(2)/beta``).
    """
    scaled1 = x1 * beta
    scaled2 = x2 * beta
    shift = torch.minimum(scaled1, scaled2)
    inv_beta = 1 / beta
    log_term = torch.log(torch.exp(shift - scaled1) + torch.exp(shift - scaled2))
    return inv_beta * shift - inv_beta * log_term

def argmin_smooth(x1, x2, beta: torch.Tensor = torch.tensor(1.0)):
    """Soft arg-min weights between ``x1`` and ``x2``.

    Stacks ``-beta*x1`` and ``-beta*x2`` along a new last axis and applies the
    module-level ``softmax_func``. Slot 0 of the result weights ``x1``, slot 1
    weights ``x2``; the two sum to 1 and the smaller input gets the larger weight.
    """
    logits = torch.stack((-beta * x1, -beta * x2), dim=-1)
    return softmax_func(logits)


def floyd_warshall(weight_matrix, beta=10000.):
    """Smoothed Floyd–Warshall all-pairs shortest distances on one graph.

    Each hard ``min`` of the classic algorithm is replaced by ``min_smooth``
    with sharpness ``beta``.

    Args:
        weight_matrix: square tensor of edge weights indexed ``[i, j]``. It is
            not modified; the computation runs on a clone.
        beta: sharpness forwarded to ``min_smooth`` (larger is closer to a hard
            min). Defaults to 10000, the value this function always used.

    Returns:
        Tensor of the same shape holding the smoothed shortest distances.
    """
    num_vertices = weight_matrix.shape[0]

    # Clone so the caller's matrix is not overwritten in place.
    shortest_distances = weight_matrix.clone()

    for k in range(num_vertices):
        for i in range(num_vertices):
            if i == k:
                continue
            for j in range(num_vertices):
                # With a hard min these updates are no-ops for a zero diagonal and
                # non-negative weights, but min_smooth(x, x) = x - log(2)/beta, so
                # visiting them would spuriously shrink distances. The batched
                # smooth variants in this file skip the same pairs.
                if j == i or j == k:
                    continue
                shortest_distances[i, j] = min_smooth(
                    shortest_distances[i, j],
                    shortest_distances[i, k] + shortest_distances[k, j],
                    beta,
                )
    return shortest_distances


def floyd_warshall_batch(weight_matrix, beta=10000.):
    """Batched smoothed Floyd–Warshall.

    Args:
        weight_matrix: tensor indexed ``[s, b, i, j]``: two leading batch axes,
            then a square node-by-node weight matrix. It is not modified; the
            computation runs on a clone.
        beta: sharpness forwarded to ``min_smooth``. Defaults to 10000, the
            value this function always used.

    Returns:
        Tensor of the same shape holding the smoothed shortest distances.
    """
    num_vertices = weight_matrix.shape[2]

    # Clone so the caller's tensor is not overwritten in place.
    shortest_distances = weight_matrix.clone()

    for k in range(num_vertices):
        for i in range(num_vertices):
            if i == k:
                continue
            for j in range(num_vertices):
                # These pairs are no-ops under a hard min, but min_smooth(x, x)
                # = x - log(2)/beta would bias distances downwards.
                if j == i or j == k:
                    continue
                shortest_distances[:, :, i, j] = min_smooth(
                    shortest_distances[:, :, i, j],
                    shortest_distances[:, :, i, k] + shortest_distances[:, :, k, j],
                    beta,
                )
    return shortest_distances


def smooth_floyd_warshall(weight_matrix, prior_matrix, beta_smooth=1.):
    """Soft Floyd–Warshall on one graph, carrying a prior cost along each soft path.

    For every pivot ``k`` and pair ``(i, j)`` with ``i, j, k`` distinct,
    ``argmin_smooth`` on the *distances* gives soft weights for "keep the
    current route" vs "detour through k". The same two weights blend the
    distance candidates (new distance) and the prior-cost candidates (new
    prior cost), as ``smooth_floyd_warshall_batch`` does.

    Args:
        weight_matrix: square tensor of edge weights indexed ``[i, j]``.
        prior_matrix: tensor of prior costs indexed like ``weight_matrix``.
        beta_smooth: sharpness forwarded to ``argmin_smooth``.

    Both inputs are cloned, so neither is modified.

    Returns:
        ``(shortest_distances, prior_matrix)``: the smoothed distances and the
        prior cost accumulated along the same soft paths.
    """
    num_vertices = weight_matrix.shape[0]

    shortest_distances = weight_matrix.clone()
    # Work on a copy so the caller's prior matrix is not overwritten.
    prior_matrix = prior_matrix.clone()

    for k in range(num_vertices):
        for i in range(num_vertices):
            if i == k:
                continue
            for j in range(num_vertices):
                # Skip pairs that are no-ops under a hard min but would be
                # biased by the smooth selection.
                if j == i or j == k:
                    continue
                direct = shortest_distances[i, j]
                detour = shortest_distances[i, k] + shortest_distances[k, j]
                # Weights come from the distances only; both updates reuse them.
                weights = argmin_smooth(direct, detour, beta=beta_smooth)

                prior_candidates = torch.stack(
                    (prior_matrix[i, j], prior_matrix[i, k] + prior_matrix[k, j]))
                shortest_distances[i, j] = torch.inner(weights, torch.stack((direct, detour)))
                prior_matrix[i, j] = torch.inner(weights, prior_candidates)
    return shortest_distances, prior_matrix


def d_is_large(d):
    """Elementwise mask that is True where ``d`` is strictly above 2500.0.

    Callers use it to skip detours whose length is above this cut-off.
    """
    cutoff = 2500.
    return d > cutoff


def smooth_floyd_warshall_batch(weight_matrix, prior_matrix, idx_paths_matrices, beta_smooth: torch.Tensor = torch.tensor(1.0)):
    """Soft Floyd–Warshall over a batch, carrying a prior cost along each soft path.

    For every pivot ``k`` and pair ``(i, j)`` with ``i, j, k`` distinct,
    ``argmin_smooth`` weighs the direct distance against the detour through
    ``k``. The same two weights blend the distance candidates and the
    prior-cost candidates. A pair is skipped for the whole batch when every
    detour length is above the ``d_is_large`` cut-off.

    Args:
        weight_matrix: tensor indexed ``[s, b, i, j]``; dim 1 is the batch axis
            and dims 2-3 form a square node-by-node matrix. Every write goes to
            a fresh clone, so the caller's tensor is never modified (presumably
            also to keep autograd's saved tensors intact).
        prior_matrix: 3-D tensor indexed ``[s, i, j]`` (asserted). It is
            repeated along a new axis 1 to match the batch size.
        idx_paths_matrices: unused.
        beta_smooth: sharpness forwarded to ``argmin_smooth``.

    Returns:
        ``(distances, prior_costs)`` after all pivots have been processed.
    """
    batch_size = weight_matrix.shape[1]
    n_nodes = weight_matrix.shape[2]

    dist = weight_matrix

    assert prior_matrix.dim() == 3
    prior = prior_matrix.unsqueeze(1).repeat(1, batch_size, 1, 1)

    for k in range(n_nodes):
        for i in range(n_nodes):
            if i == k:
                continue
            for j in range(n_nodes):
                if j == i or j == k:
                    continue

                detour = dist[:, :, i, k] + dist[:, :, k, j]
                if d_is_large(detour).all():
                    continue

                direct = dist[:, :, i, j]
                weights = argmin_smooth(direct, detour, beta=beta_smooth)

                # Distance: weighted average of (direct, detour).
                dist_options = torch.stack((direct, detour), dim=2)
                blended_dist = (dist_options * weights).sum(2)
                dist = dist.clone()
                dist[:, :, i, j] = blended_dist

                # Prior cost: same weights applied to the prior candidates.
                prior_options = torch.stack(
                    (prior[:, :, i, j], prior[:, :, i, k] + prior[:, :, k, j]), dim=2)
                blended_prior = (prior_options * weights).sum(2)
                prior = prior.clone()
                prior[:, :, i, j] = blended_prior
    return dist, prior



@torch.jit.script
def smooth_floyd_warshall_batch_adapted(weight_matrix, M_indices, 
                                        beta_smooth: torch.Tensor = torch.tensor(1.0)):
    """Soft Floyd–Warshall that tracks how much weight each pivot receives.

    ``argmins_tensor[b, i, j, m]`` starts at 1 for ``m == i`` on every edge
    ``(i, j)`` listed in ``M_indices`` (the direct edge) and at 0 elsewhere.
    Consider each pivot ``k`` and each pair ``(i, j)`` with ``i, j, k`` distinct:
    - the existing weights of ``(i, j)`` are scaled by the soft probability of
      keeping the current route;
    - slot ``k`` receives the probability of the detour through ``k``;
    - the distance becomes the correspondingly weighted average.
    A pair is skipped for a pivot when every batch element's detour is above
    the ``d_is_large`` cut-off.

    Args:
        weight_matrix: tensor indexed ``[b, i, j]``: batch, then a square
            node-by-node weight matrix. It is cloned before each write, so the
            caller's tensor is not modified.
        M_indices: integer tensor whose rows are ``(i, j)`` edge index pairs.
        beta_smooth: sharpness forwarded to ``argmin_smooth``.

    Returns:
        ``argmins_tensor`` of shape ``(b, V, V, V)`` in the default float dtype,
        on the same device as ``weight_matrix``.
    """
    bs = weight_matrix.shape[0]  # Batch size
    num_vertices = weight_matrix.shape[1]  # Number of nodes

    shortest_distances = weight_matrix

    # Allocate on weight_matrix's device so CUDA inputs are not mixed with a
    # CPU buffer in the indexed writes and products below.
    argmins_tensor = torch.zeros((bs, num_vertices, num_vertices, num_vertices),
                                 device=weight_matrix.device)

    # Slot m == i marks the direct edge i -> j.
    argmins_tensor[:, M_indices[:, 0], M_indices[:, 1], M_indices[:, 0]] = 1.

    for k in range(0, num_vertices):
        for i in range(0, num_vertices):
            for j in range(0, num_vertices):

                # Cases where we don't need to go through (save time)
                if i == k or j == i or j == k:
                    continue

                shortcut = shortest_distances[:, i, k] + shortest_distances[:, k, j]
                direct = shortest_distances[:, i, j]

                if (d_is_large(shortcut)).all():
                    continue

                argmins = argmin_smooth(direct, shortcut, beta=beta_smooth)

                # Scale the current route weights by P(keep the current route) ...
                argmins_tensor_clone = argmins_tensor[:, i, j, :].clone()
                argmins_tensor[:, i, j, :] = argmins_tensor_clone * (argmins[:, 0].unsqueeze(-1))
                # ... and record P(detour through k) in slot k.
                argmins_tensor[:, i, j, k] = argmins[:, 1]

                vals = torch.stack((direct, shortcut), dim=1)
                new_shortest_distances = (vals * argmins).sum(1)

                shortest_distances = shortest_distances.clone()
                shortest_distances[:, i, j] = new_shortest_distances

    return argmins_tensor



#@torch.jit.script
def smooth_floyd_warshall_batch_adapted_parallel(weight_matrix, M_indices, dev: str,
                                                 beta_smooth: torch.Tensor = torch.tensor(1.0)):
    """Vectorised per-pivot variant of ``smooth_floyd_warshall_batch_adapted``.

    For each pivot ``k`` all pairs ``(i, j)`` are updated at once. This is safe
    because the entries read for the detour (column ``k`` and row ``k``) are
    themselves masked out of the update for that pivot.

    Args:
        weight_matrix: tensor indexed ``[b, i, j]``: batch, then a square
            node-by-node weight matrix.
        M_indices: integer tensor whose rows are ``(i, j)`` edge index pairs;
            ``argmins_tensor[:, i, j, i]`` starts at 1 for each of them.
        dev: device on which ``argmins_tensor`` is allocated.
        beta_smooth: sharpness for ``argmin_smooth`` and ``min_smooth``.

    Returns:
        float32 ``argmins_tensor`` of shape ``(b, V, V, V)``.

    NOTE(review): unlike the sequential version, this one differs in two ways.
    Distances are updated with ``min_smooth`` (log-sum-exp) instead of the
    argmin-weighted average. The ``d_is_large`` cut-off is applied per batch
    element rather than only when the whole batch is large. Confirm both are
    intended.
    """
    bs = weight_matrix.shape[0]  # Batch size
    num_vertices = weight_matrix.shape[1]  # Number of nodes
    mask_device = weight_matrix.device

    shortest_distances = weight_matrix
    argmins_tensor = torch.zeros((bs, num_vertices, num_vertices, num_vertices), dtype=torch.float32).to(dev)
    argmins_tensor[:, M_indices[:, 0], M_indices[:, 1], M_indices[:, 0]] = 1.

    # Pairs with i == j are never updated (the sequential version skips them
    # too). This part of the mask does not depend on k, so build it once.
    off_diagonal = ~torch.eye(num_vertices, dtype=torch.bool, device=mask_device)

    for k in range(num_vertices):
        # Broadcasting to create a matrix of shape (bs, num_vertices, num_vertices)
        shortcut = shortest_distances[:, :, k, None] + shortest_distances[:, k, None, :]
        direct = shortest_distances[:, :, :]

        # Exclude i == k and j == k for this pivot, on top of i == j.
        not_k = torch.ones(num_vertices, dtype=torch.bool, device=mask_device)
        not_k[k] = False
        mask = off_diagonal & not_k[:, None] & not_k[None, :]
        mask = mask[None, :, :].expand(bs, num_vertices, num_vertices)

        # Leave entries alone where the detour is above the d_is_large cut-off.
        large_mask = ~d_is_large(shortcut)
        mask = mask & large_mask

        argmins = argmin_smooth(direct, shortcut, beta=beta_smooth)

        argmins_tensor_clone = argmins_tensor.clone()
        # Scale the current route weights by P(keep the current route) ...
        argmins_tensor = torch.where(mask.unsqueeze(-1), argmins_tensor_clone * argmins[:, :, :, 0].unsqueeze(-1), argmins_tensor)
        # ... and record P(detour through k) in slot k.
        argmins_tensor[:, :, :, k] = torch.where(mask, argmins[:, :, :, 1], argmins_tensor[:, :, :, k])

        new_shortest_distances = min_smooth(direct, shortcut, beta=beta_smooth)

        shortest_distances = torch.where(mask, new_shortest_distances, shortest_distances)

    return argmins_tensor




from torch.distributions import Normal
from torch import jit
from math import sqrt, pi

@jit.script
def normal_cdf(x):
    """Elementwise standard normal CDF: Phi(x) = (erf(x / sqrt(2)) + 1) / 2."""
    scaled = x / sqrt(2.0)
    return 0.5 * (torch.erf(scaled) + 1)

@jit.script
def probabilistic_argmin(a, b, sa, sb):
    """Probability that a Gaussian draw around ``a`` is below one around ``b``.

    With independent A ~ N(a, sa^2) and B ~ N(b, sb^2),
    P(A < B) = 1 - Phi(z), where z = (a - b) / sqrt(sa^2 + sb^2).

    The combined standard deviation is clamped away from zero. When both
    ``sa`` and ``sb`` are 0 and ``a == b``, the unclamped ratio is 0/0 = NaN,
    which would propagate into any blend built from these weights. With the
    clamp such a tie yields 0.5, and strict inequalities still yield 1 or 0.
    """
    spread = torch.sqrt(sa**2 + sb**2).clamp_min(1e-12)
    z = (a - b) / spread
    return 1 - normal_cdf(z)

def exclude_row_col(tensor, index):
    """Drop row ``index`` (dim 1) and column ``index`` (dim 2) of a batched matrix."""

    def _drop(t, dim):
        # Keep everything before and after ``index`` along ``dim`` and rejoin.
        lead = (slice(None),) * dim
        head = t[lead + (slice(None, index),)]
        tail = t[lead + (slice(index + 1, None),)]
        return torch.cat((head, tail), dim=dim)

    return _drop(_drop(tensor, 1), 2)

@jit.script
def remove_node_and_adjust_vectorized(adjacency_matrix, node, M_indices, 
                                      beta_smooth: torch.Tensor = torch.tensor(1.0)):
    """Fold paths through ``node`` into the remaining edges with a smooth min.

    Consider every edge ``(i, j)`` in ``M_indices`` that is not a self-loop and
    does not touch ``node``. Its entry becomes
    ``min_smooth(A[:, i, j], A[:, i, node] + A[:, node, j], beta_smooth)``.
    All values are read from the input ``adjacency_matrix``, so every update
    uses the same snapshot. The input itself is left unmodified.

    Args:
        adjacency_matrix: tensor indexed ``[b, i, j]``.
        node: index of the node being removed (unannotated, so TorchScript
            types it as a Tensor).
        M_indices: integer tensor whose rows are ``(i, j)`` edge index pairs.
        beta_smooth: sharpness forwarded to ``min_smooth``.

    Returns:
        A copy of ``adjacency_matrix`` with the adjusted entries.
    """
    mask = ~(M_indices[:, 0] == node) & ~(M_indices[:, 1] == node) & ~(M_indices[:, 0] == M_indices[:, 1])
    valid_edges = M_indices[mask]

    i_indices = valid_edges[:, 0]
    j_indices = valid_edges[:, 1]

    shortcut = adjacency_matrix[:, :, node][:, i_indices] + adjacency_matrix[:, node, :][:, j_indices]
    direct = adjacency_matrix[:, i_indices, j_indices]

    adj_best = adjacency_matrix.clone()
    adj_best[:, i_indices, j_indices] = min_smooth(direct, shortcut, beta=beta_smooth)

    return adj_best

@jit.script
def remove_node_and_adjust(adjacency_matrix, sigma_matrix, node, M_indices):
    """Fold paths through ``node`` into the other edges using Gaussian uncertainty.

    Consider each edge ``(i, j)`` in ``M_indices`` that is not a self-loop and
    does not touch ``node``, and whose detour through ``node`` is not above 2500
    for every batch element. Its entry becomes a blend of direct and detour
    lengths, weighted by ``probabilistic_argmin``. The detour's sigma combines
    the two leg sigmas in quadrature. Values are read from the input matrix, so
    all updates use the same snapshot; the input is left unmodified.
    """
    adjusted = adjacency_matrix.clone()
    for pair in M_indices:
        src = pair[0]
        dst = pair[1]

        # Self-loops and edges touching the removed node are left untouched.
        if src == node or dst == node or src == dst:
            continue

        via_node = adjacency_matrix[:, :, node][:, src] + adjacency_matrix[:, node, :][:, dst]

        # Nothing to gain when every detour in the batch is above the cut-off.
        if (via_node > 2500.0).all():
            continue

        direct = adjacency_matrix[:, src, dst]
        sigma_via = torch.sqrt(sigma_matrix[:, :, node][:, src]**2 + sigma_matrix[:, node, :][:, dst]**2)
        p_direct = probabilistic_argmin(direct, via_node, sigma_matrix[:, src, dst], sigma_via)

        adjusted[:, src, dst] = direct * p_direct + via_node * (1 - p_direct)

    return adjusted





import heapq
from multiprocessing import Pool, cpu_count

from scipy.sparse.csgraph import dijkstra as dij
from scipy.sparse import csr_matrix

def dijkstra(adj_matrix_and_indices):
    """Shortest path between two nodes of one graph (single-argument for ``Pool.map``).

    Args:
        adj_matrix_and_indices: tuple
            ``(adjacency_matrix, start_node, end_node, matrix_index)``.
            ``adjacency_matrix`` goes through ``csr_matrix``, so zero entries
            are not edges.

    Returns:
        ``(matrix_index, path)``, where ``path`` lists node indices from
        ``start_node`` to ``end_node`` inclusive.

    Raises:
        ValueError: if ``end_node`` is not reachable from ``start_node``.
    """
    adjacency_matrix, start_node, end_node, matrix_index = adj_matrix_and_indices
    graph = csr_matrix(adjacency_matrix)
    _, predecessors = dij(csgraph=graph,
                          directed=True,
                          indices=start_node,
                          return_predecessors=True)
    path = [end_node]
    while path[-1] != start_node:
        previous = predecessors[path[-1]]
        # scipy marks "no predecessor" with a negative sentinel (-9999); walking
        # it would index out of range or wrap around to an unrelated node.
        if previous < 0:
            raise ValueError(f"node {end_node} is not reachable from node {start_node}")
        path.append(previous)
    path.reverse()
    return matrix_index, path

def batch_dijkstra(adjacency_matrices, node_pairs):
    """Solve one start/end shortest-path query per graph, in parallel.

    :param adjacency_matrices: Batch of adjacency matrices of shape (B, V, V)
    :param node_pairs: Matrix of shape (B, 2) where first column is start node, second is end node
    :return: list of B paths (each a list of node indices), in batch order
    """
    with Pool(cpu_count()) as pool:
        jobs = [
            (adjacency_matrices[b], node_pairs[b][0], node_pairs[b][1], b)
            for b in range(len(adjacency_matrices))
        ]
        indexed_paths = pool.map(dijkstra, jobs)
    # Each result carries its batch index; order by it before stripping it off.
    ordered = sorted(indexed_paths, key=lambda item: item[0])
    return [path for _, path in ordered]


def get_optimal_path_matrix(adjacency_matrix, start_node, end_node):
    """Return a matrix marking the edges of the shortest ``start_node -> end_node`` path.

    The result has the shape and dtype of ``adjacency_matrix`` (via
    ``np.zeros_like``), with 1 at ``[u, v]`` for each consecutive step of the
    path and 0 elsewhere.
    """
    # dijkstra takes a single (matrix, start, end, index) tuple and returns
    # (index, path); the index is irrelevant for a single query.
    _, path = dijkstra((adjacency_matrix, start_node, end_node, 0))
    optimal_path_matrix = np.zeros_like(adjacency_matrix)
    for src, dst in zip(path[:-1], path[1:]):
        optimal_path_matrix[src, dst] = 1
    return optimal_path_matrix


def edges_on_pred_func(best_pred_path, M_indices):
    """Indicator matrix of which candidate edges each path traverses.

    Args:
        best_pred_path: sequence of paths, each a sequence of node indices.
        M_indices: edge list with rows ``(u, v)``, given as a torch tensor or
            an array-like.

    Returns:
        float ndarray of shape ``(len(best_pred_path), len(M_indices))``.
        Entry ``[p, e]`` is 1 if path ``p`` visits ``u`` immediately followed
        by ``v`` for edge ``e``, and 0 otherwise.
    """
    # Convert once rather than once per path; .cpu() also admits CUDA tensors.
    if isinstance(M_indices, torch.Tensor):
        edge_list = M_indices.detach().cpu().numpy()
    else:
        edge_list = np.asarray(M_indices)

    edges_on_pred = np.zeros((len(best_pred_path), edge_list.shape[0]))
    for p, path in enumerate(best_pred_path):
        # Set of consecutive steps gives O(1) membership tests instead of
        # scanning the whole path for every candidate edge.
        steps = {(int(u), int(v)) for u, v in zip(path[:-1], path[1:])}
        for e, (u, v) in enumerate(edge_list):
            if (int(u), int(v)) in steps:
                edges_on_pred[p, e] = 1
    return edges_on_pred


# Code below is based on Vlastelica et al., "Differentiation of Blackbox Combinatorial Solvers" (arXiv 2019, ICLR 2020).
class SolverDiff(torch.autograd.Function):
    """Shortest-path solver made differentiable by blackbox interpolation.

    forward: builds cost matrices from the edge costs ``dY`` via
    ``utils.costs_to_matrix``, solves each instance with ``batch_dijkstra``
    and returns a float32 0/1 matrix with one row per path and one column per
    row of ``M_indices`` (edges on the predicted path).

    backward: re-solves at perturbed costs ``Y' = Y + lambda_val * grad_output``
    and returns ``-(y(Y) - y(Y')) / lambda_val`` as the gradient for ``dY``.

    The solver configuration is class-level state set through
    ``set_parameters``, so it is shared by all calls.
    """
    lambda_val = None  # interpolation strength; the backward gradient divides by it
    prior_M = None  # forwarded to utils.costs_to_matrix
    M_indices = None  # edge list forwarded to utils.costs_to_matrix and edges_on_pred_func

    @classmethod
    def set_parameters(cls, lambda_val, prior_M, M_indices):
        """Store the solver configuration shared by every forward/backward call."""
        cls.lambda_val = lambda_val
        cls.prior_M = prior_M
        cls.M_indices = M_indices

    @staticmethod
    def forward(ctx, dY, end_to_end_nodes_batch):
        # .cpu() is a no-op for CPU tensors and lets CUDA inputs reach NumPy.
        ctx.Y = dY.detach().cpu().numpy()
        ctx.device = dY.device
        ctx.end_to_end_nodes_batch = end_to_end_nodes_batch

        M = utils.costs_to_matrix(SolverDiff.prior_M, SolverDiff.M_indices, dY)
        best_pred_paths = batch_dijkstra(M, end_to_end_nodes_batch)
        edges_on_pred_np = edges_on_pred_func(best_pred_paths, SolverDiff.M_indices)

        ctx.edges_on_pred_np = edges_on_pred_np
        # Return on the input's device so downstream ops don't mix devices.
        return torch.tensor(edges_on_pred_np, dtype=torch.float32, device=dY.device)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output_numpy = grad_output.detach().cpu().numpy()
        # Perturb the costs in the direction of the incoming gradient.
        Y_prime = ctx.Y + SolverDiff.lambda_val * grad_output_numpy

        M_prime = utils.costs_to_matrix(SolverDiff.prior_M, SolverDiff.M_indices, Y_prime)
        best_pred_paths_prime = batch_dijkstra(M_prime, ctx.end_to_end_nodes_batch)
        edges_on_pred_prime_np = edges_on_pred_func(best_pred_paths_prime, SolverDiff.M_indices)

        gradient = -(ctx.edges_on_pred_np - edges_on_pred_prime_np) / SolverDiff.lambda_val
        # Autograd requires the gradient on the same device as the input dY.
        return torch.from_numpy(gradient.astype(np.float32)).to(ctx.device), None